{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy import stats\n",
    "from sklearn.metrics import jaccard_score, accuracy_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the labeled test sets: `stance` for the stance task, `sent` for the sentiment task.\n",
    "stance, sent = (pd.read_csv(csv_path) for csv_path in ('data/stance_test_labeled.csv',\n",
    "                                                         'data/sentiment_test_labeled.csv'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def acc_statistic(x1, x2, y):\n",
    "    \"\"\"\n",
    "    Difference in accuracy between two prediction vectors scored against the same labels.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x1, x2 : array-like\n",
    "        Predictions from the two systems being compared, aligned element-wise with `y`.\n",
    "    y : array-like\n",
    "        Ground-truth labels. `accuracy_score` handles multiclass labels, so the\n",
    "        inputs do not need to be binary.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        accuracy_score(y, x1) - accuracy_score(y, x2); positive when `x1` is more accurate.\n",
    "    \"\"\"\n",
    "    acc_x1 = accuracy_score(y, x1)\n",
    "    acc_x2 = accuracy_score(y, x2)\n",
    "    return acc_x1 - acc_x2\n",
    "\n",
    "def run_permutation_test_acc(docs, true, labels1, labels2, n_permutations=10000, alternative = 'greater', random_seed = 1):\n",
    "    # Extract the data from the two columns\n",
    "    x1 = docs[labels1].values\n",
    "    x2 = docs[labels2].values\n",
    "    y = docs[true].values\n",
    "    \n",
    "    # Run the permutation test\n",
    "    result = stats.permutation_test((x1, x2, y), acc_statistic, \n",
    "                                    n_resamples=n_permutations, \n",
    "                                    alternative=alternative,\n",
    "                                    random_state = random_seed)\n",
    "    \n",
    "    return result"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Stance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.19418058194180582\n",
      "0.2586741325867413\n",
      "0.8533146685331466\n",
      "SignificanceResult(statistic=6.299559612288608, pvalue=0.3904832560378262)\n",
      "SignificanceResult(statistic=0.26516287717105, pvalue=0.3954419956414165)\n",
      "0.19418058194180582\n",
      "0.2586741325867413\n",
      "0.8533146685331466\n",
      "SignificanceResult(statistic=6.299559612288608, pvalue=0.3904832560378262)\n",
      "SignificanceResult(statistic=0.26516287717105, pvalue=0.3954419956414165)\n"
     ]
    }
   ],
   "source": [
    "# Stance classification: for each model, test whether its `*_val` predictions (x1) are less\n",
    "# accurate than its `*_sent` predictions (x2) on the stance test set (alternative='less').\n",
    "# NOTE(review): Fisher/Stouffer assume independent tests, but all three models are scored\n",
    "# on the same items -- interpret the combined p-values with caution.\n",
    "gpt4val, sonnetval, llamaval = (\n",
    "    run_permutation_test_acc(stance, 'stance', col_x1, col_x2, alternative = 'less')\n",
    "    for col_x1, col_x2 in [('gpt4_val', 'gpt4_sent'),\n",
    "                           ('sonnet_val', 'sonnet_sent'),\n",
    "                           ('llama_val', 'llama_sent')]\n",
    ")\n",
    "\n",
    "# Per-model p-values, one per line.\n",
    "print(*(result.pvalue for result in (gpt4val, sonnetval, llamaval)), sep='\\n')\n",
    "\n",
    "# Combined evidence across the three models (Fisher, then Stouffer).\n",
    "print(*(stats.combine_pvalues([gpt4val.pvalue, sonnetval.pvalue, llamaval.pvalue], method = method_name)\n",
    "        for method_name in ('fisher', 'stouffer')), sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Valence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.2685731426857314\n",
      "0.37266273372662734\n",
      "0.22507749225077492\n",
      "SignificanceResult(statistic=7.586047966920749, pvalue=0.27002528519145086)\n",
      "SignificanceResult(statistic=0.9798211280376468, pvalue=0.16358721049497793)\n",
      "0.2685731426857314\n",
      "0.37266273372662734\n",
      "0.22507749225077492\n",
      "SignificanceResult(statistic=7.586047966920749, pvalue=0.27002528519145086)\n",
      "SignificanceResult(statistic=0.9798211280376468, pvalue=0.16358721049497793)\n"
     ]
    }
   ],
   "source": [
    "# Sentiment classification: for each model, compare its `*_sent` predictions (x1) with its\n",
    "# `*_val` predictions (x2) on the sentiment test set; alternative='greater' tests whether x1\n",
    "# is more accurate. The alternative must be the same for all three models: Fisher/Stouffer\n",
    "# combine one-sided p-values and are only meaningful when every test points the same way.\n",
    "# NOTE(review): both methods also assume independent tests, but all three models are\n",
    "# scored on the same items -- interpret the combined p-values with caution.\n",
    "gpt4val = run_permutation_test_acc(sent, 'sentiment', 'gpt4_sent', 'gpt4_val', alternative = 'greater')\n",
    "sonnetval = run_permutation_test_acc(sent, 'sentiment', 'sonnet_sent', 'sonnet_val', alternative = 'greater')\n",
    "llamaval = run_permutation_test_acc(sent, 'sentiment', 'llama_sent', 'llama_val', alternative = 'greater')\n",
    "\n",
    "print(gpt4val.pvalue)\n",
    "print(sonnetval.pvalue)\n",
    "print(llamaval.pvalue)\n",
    "\n",
    "print(stats.combine_pvalues([gpt4val.pvalue, sonnetval.pvalue, llamaval.pvalue], method = 'fisher'))\n",
    "print(stats.combine_pvalues([gpt4val.pvalue, sonnetval.pvalue, llamaval.pvalue], method = 'stouffer'))"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
